import llamahf
import torch
import pandas as pd
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer

# To save memory on CPU, uncomment the next line to make bfloat16 the default dtype.
# torch.set_default_dtype(torch.bfloat16)

# Hugging Face model id of the base checkpoint to fine-tune.
MODEL = 'decapoda-research/llama-7b-hf'
# CSV with a 'text' column holding one training example per row.
DATA_FILE_PATH = 'datasets/elon_musk_tweets.csv'
# Where the fine-tuned model, tokenizer and checkpoints are written.
OUTPUT_DIR = './trained'

# Empty cells in the 'text' column are read by pandas as float NaN, which the
# tokenizer cannot encode; drop them so every remaining entry is real text.
texts = pd.read_csv(DATA_FILE_PATH)['text'].dropna()

# Load the pretrained tokenizer and model; the model is kept on CPU.
tokenizer = llamahf.LLaMATokenizer.from_pretrained(MODEL)
model = llamahf.LLaMAForCausalLM.from_pretrained(MODEL).cpu()


class TextDataset(Dataset):
    """Map-style dataset of pre-tokenized texts for causal-LM fine-tuning.

    Every text is tokenized once, up front, with truncation to ``max_length``
    tokens and no padding, so items may have different lengths.

    Args:
        txt_list: iterable of strings to tokenize.
        tokenizer: callable tokenizer (HF-style) returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` sequences.
        max_length: maximum number of tokens kept per text.
    """

    def __init__(self, txt_list, tokenizer, max_length):
        self.input_ids = []
        self.attn_masks = []
        for txt in txt_list:
            # `padding=False` is the non-deprecated spelling of the former
            # `pad_to_max_length=False`: keep each sequence at its own length.
            encodings_dict = tokenizer(txt, truncation=True, max_length=max_length, padding=False)
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        """Return the number of tokenized texts."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` tensors for item ``idx``."""
        return self.input_ids[idx], self.attn_masks[idx]


# Truncation limit = length of the longest tokenized text, so (presuming
# `encode` and `__call__` tokenize identically) no example is actually cut.
max_length = max(len(tokenizer.encode(text)) for text in texts)
dataset = TextDataset(texts, tokenizer, max_length=max_length)

# 90/10 train/validation split. A fixed-seed generator makes the split
# reproducible across runs instead of depending on torch's global RNG state.
n_train = int(0.9 * len(dataset))
train_dataset, val_dataset = random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(42),
)

# CPU-only fine-tuning, one example per device step.
training_args = TrainingArguments(
    save_steps=5000,
    warmup_steps=10,
    logging_steps=100,
    weight_decay=0.05,
    num_train_epochs=1,
    logging_dir='./logs',
    output_dir=OUTPUT_DIR,
    no_cuda=True,
    per_device_eval_batch_size=1,
    per_device_train_batch_size=1)


def collate_batch(batch):
    """Collate ``(input_ids, attention_mask)`` pairs into a padded batch dict.

    Items are not padded by the dataset, so sequences are right-padded here to
    the longest one in the batch; with a batch of one item nothing is padded.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors of
        shape ``(batch, longest_len)``; labels are a copy of the input ids.
    """
    ids = [input_ids for input_ids, _ in batch]
    masks = [attn_mask for _, attn_mask in batch]
    pad = torch.nn.utils.rnn.pad_sequence
    return {
        # Pad id 0 is arbitrary: padded positions come after all real tokens,
        # are masked out, and their labels are ignored by the loss.
        'input_ids': pad(ids, batch_first=True, padding_value=0),
        'attention_mask': pad(masks, batch_first=True, padding_value=0),
        # -100 is the default ignore_index of torch's cross-entropy loss;
        # assumes the model's LM loss keeps that default — confirm for llamahf.
        'labels': pad(ids, batch_first=True, padding_value=-100),
    }


trainer = Trainer(model=model,
                  args=training_args,
                  eval_dataset=val_dataset,
                  train_dataset=train_dataset,
                  data_collator=collate_batch)

trainer.train()
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)
del trainer

# Trainer.train() runs the model in training mode; switch to eval mode
# explicitly before sampling so train-only layers (e.g. dropout, if any) are off.
model.eval()

# Sample one continuation from an empty prompt with top-k / nucleus sampling.
prompt_ids = tokenizer('', return_tensors="pt").input_ids.cpu()
sample_outputs = model.generate(prompt_ids,
                                do_sample=True,
                                top_k=50,
                                max_length=300,
                                top_p=0.95,
                                temperature=1.0)

print(tokenizer.decode(sample_outputs[0]))
